{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using Natural Gradient Descent with Variational Models\n",
    "\n",
    "## Overview\n",
    "\n",
    "In this notebook, we'll demonstrate how to use **natural gradient descent** when optimizing variational GPyTorch models. This will be the same as the [SVGP regression notebook](./SVGP_Regression_CUDA.ipynb), except we will be using a different optimizer.\n",
    "\n",
    "## What is natural gradient descent (NGD)?\n",
    "\n",
    "Without going into too much detail, using SGD or Adam isn't the best way to optimize the parameters of variational Gaussian distributions. Essentially, SGD takes steps assuming that the loss geometry of the parameters is Euclidean. This is a bad assumption for the parameters of many distributions, especially Gaussians. See [Agustinus Kristiadi's blog post](https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/) for more detail.\n",
    "\n",
    "Instead, it makes more sense to take gradient steps that have been scaled better for the geometry of the variational distribution's parameters. Specifically, if $\\mathbf m$ and $\\mathbf S$ are the mean and covariance of our variational Gaussian posterior approximation, then we will achieve faster convergence if we take the following steps:\n",
    "\n",
    "\\begin{align}\n",
    "\\begin{bmatrix} \\mathbf m \\\\ \\mathbf S \\end{bmatrix}\n",
    "\\leftarrow\n",
    "\\begin{bmatrix} \\mathbf m \\\\ \\mathbf S \\end{bmatrix}\n",
    "-\n",
    "\\alpha \\mathcal{\\mathbf F}^{-1} \\nabla \\begin{bmatrix} \\mathbf m \\\\   \\mathbf S \\end{bmatrix}\n",
    "\\end{align}\n",
    "\n",
    "where $\\alpha$ is a step size and $\\mathbf F$ is the **Fisher information matrix** corresponding to this distribution. This is known as **natural gradient descent**, or **NGD**.\n",
    "\n",
    "It turns out that for Gaussian distributions (and, more broadly, for all distributions in the exponential family), there are efficient update equations for NGD. See the following papers for more information:\n",
    "- [Salimbeni, Hugh, Stefanos Eleftheriadis, and James Hensman. \"Natural gradients in practice: Non-conjugate variational inference in gaussian process models.\" AISTATS (2018).](https://arxiv.org/abs/1803.09151)\n",
    "- [Hensman, James, Magnus Rattray, and Neil D. Lawrence. \"Fast variational inference in the conjugate exponential family.\" NeurIPS (2012).](http://papers.nips.cc/paper/4766-fast-variational-inference-in-the-conjugate-exponential-family)\n",
    "\n",
    "## Jointly optimizing variational parameters/hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Make plots inline\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this example notebook, we'll be using the `elevators` UCI dataset used in the paper. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately. For this notebook, we'll simply be splitting the data using the first 80% of the data as training and the last 20% as testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import os\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "\n",
    "\n",
    "# this is for running the notebook in our testing framework\n",
    "smoke_test = ('CI' in os.environ)\n",
    "\n",
    "\n",
    "if not smoke_test and not os.path.isfile('../elevators.mat'):\n",
    "    print('Downloading \\'elevators\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')\n",
    "\n",
    "\n",
    "if smoke_test:  # this is for running the notebook in our testing framework\n",
    "    X, y = torch.randn(1000, 3), torch.randn(1000)\n",
    "else:\n",
    "    data = torch.Tensor(loadmat('../elevators.mat')['data'])\n",
    "    X = data[:, :-1]\n",
    "    X = X - X.min(0)[0]\n",
    "    X = 2 * (X / X.max(0)[0]) - 1\n",
    "    y = data[:, -1]\n",
    "\n",
    "\n",
    "train_n = int(floor(0.8 * len(X)))\n",
    "train_x = X[:train_n, :].contiguous()\n",
    "train_y = y[:train_n].contiguous()\n",
    "\n",
    "test_x = X[train_n:, :].contiguous()\n",
    "test_y = y[train_n:].contiguous()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following steps create the dataloader objects. See the [SVGP regression notebook](./SVGP_Regression_CUDA.ipynb) for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "train_dataset = TensorDataset(train_x, train_y)\n",
    "train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)\n",
    "\n",
    "test_dataset = TensorDataset(test_x, test_y)\n",
    "test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SVGP models for NGD\n",
    "\n",
    "There are **three** key differences between NGD-SVGP models and the standard SVGP models from the [SVGP regression notebook](./SVGP_Regression_CUDA.ipynb).\n",
    "\n",
    "\n",
    "### Difference #1: NaturalVariationalDistribution\n",
    "\n",
    "Rather than using `gpytorch.variational.CholeskyVarationalDistribution` (or other variational distribution objects), you have to use one of the two objects:\n",
    "\n",
    "- `gpytorch.variational.NaturalVariationalDistribution` (typically faster optimization convergence, but less stable for non-conjugate likelihoods)\n",
    "- `gpytorch.variational.TrilNaturalVariationalDistribution` (typically slower optimization convergence, but more stable for non-conjugate likelihoods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPModel(gpytorch.models.ApproximateGP):\n",
    "    def __init__(self, inducing_points):\n",
    "        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(inducing_points.size(0))\n",
    "        variational_strategy = gpytorch.variational.VariationalStrategy(\n",
    "            self, inducing_points, variational_distribution, learn_inducing_locations=True\n",
    "        )\n",
    "        super(GPModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "        \n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "\n",
    "inducing_points = train_x[:500, :]\n",
    "model = GPModel(inducing_points=inducing_points)\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    likelihood = likelihood.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Difference #2: Two optimizers - one for the variational parameters; one for the hyperparameters\n",
    "\n",
    "NGD steps only update the variational parameters. Therefore, we need two separate optimizers: one for the variational parameters (using NGD) and one for the other hyperparameters (using Adam or whatever you want).\n",
    "\n",
    "Some things to note about the NGD variational optimizer:\n",
    "\n",
    "- **You must use `gpytorch.optim.NGD` as the variational NGD optimizer!** Adaptive gradient algorithms will mess up the natural gradient steps. (Any stochastic optimizer works for the hyperparameters.)\n",
    "- **Use a large learning rate for the variational optimizer.** Typically, 0.1 is a good learning rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "variational_ngd_optimizer = gpytorch.optim.NGD(model.variational_parameters(), num_data=train_y.size(0), lr=0.1)\n",
    "\n",
    "hyperparameter_optimizer = torch.optim.Adam([\n",
    "    {'params': model.hyperparameters()},\n",
    "    {'params': likelihood.parameters()},\n",
    "], lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Difference #3: The updated training loop\n",
    "\n",
    "In the training loop, we have to update both optimizers (`variational_ngd_optimizer` and `hyperparameter_optimizer`). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "126c78ce830a4352aba355851d55ad5a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=4.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=13.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=13.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=13.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=13.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "likelihood.train()\n",
    "mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))\n",
    "\n",
    "num_epochs = 1 if smoke_test else 4\n",
    "epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc=\"Epoch\")\n",
    "for i in epochs_iter:\n",
    "    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc=\"Minibatch\", leave=False)\n",
    "    \n",
    "    for x_batch, y_batch in minibatch_iter:\n",
    "        ### Perform NGD step to optimize variational parameters\n",
    "        variational_ngd_optimizer.zero_grad()\n",
    "        hyperparameter_optimizer.zero_grad()\n",
    "        output = model(x_batch)\n",
    "        loss = -mll(output, y_batch)\n",
    "        minibatch_iter.set_postfix(loss=loss.item())\n",
    "        loss.backward()\n",
    "        variational_ngd_optimizer.step()\n",
    "        hyperparameter_optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You could also modify the optimization loop to alternate between the NGD step and the hyperparameter updates:\n",
    "\n",
    "```python\n",
    "    for x_batch, y_batch in minibatch_iter:\n",
    "        ### Perform NGD step to optimize variational parameters\n",
    "        variational_ngd_optimizer.zero_grad()\n",
    "        output = model(x_batch)\n",
    "        loss = -mll(output, y_batch)\n",
    "        minibatch_iter.set_postfix(loss=loss.item())\n",
    "        loss.backward()\n",
    "        variational_ngd_optimizer.step()\n",
    "        \n",
    "        ### Perform Adam step to optimize hyperparameters\n",
    "        hyperparameter_optimizer.zero_grad()\n",
    "        output = model(x_batch)\n",
    "        loss = -mll(output, y_batch)\n",
    "        loss.backward()\n",
    "        hyperparameter_optimizer.step()\n",
    "```\n",
    "\n",
    "### Evaluation\n",
    "\n",
    "That's it! This model should converge must faster/better than the standard SVGP model - and it can often get better performance (especially when using more inducing points)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test MAE: 0.07514254748821259\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "likelihood.eval()\n",
    "means = torch.tensor([0.])\n",
    "with torch.no_grad():\n",
    "    for x_batch, y_batch in test_loader:\n",
    "        preds = model(x_batch)\n",
    "        means = torch.cat([means, preds.mean.cpu()])\n",
    "means = means[1:]\n",
    "print('Test MAE: {}'.format(torch.mean(torch.abs(means - test_y.cpu()))))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
